from math import log
import operator
import pickle
import os
import numpy as np

def debug(value_name,value):
    """Print a labelled dump of ``value`` to stdout.

    Args:
        value_name: Name shown in the header line.
        value: Object printed on the line after the header.
    """
    header = "debuging for %s" % value_name
    print(header, value, sep="\n")

# Data loading: builds the feature-label list and attaches wine_label to the data.

def loadDateset(path=r'/workspace/wine/wine.data'):
    """Load the comma-separated wine file and move the class label last.

    Every non-blank line is stripped and split on commas. The first field of
    each row is taken as the class label and the remaining fields as feature
    values. All values are kept as strings; no numeric conversion is done.

    Args:
        path: Location of the data file. Defaults to the path this script
            was originally hard-wired to.

    Returns:
        A ``(wine_data, featLabels)`` tuple. ``wine_data`` is a 2-D numpy
        array of strings whose last column is the class label.
        ``featLabels`` is ``[0, 1, ..., n_features - 1]``: one integer label
        per feature column.

    Raises:
        ValueError: If the file contains no data rows.
    """
    with open(path) as f:
        # Skip blank lines (e.g. a trailing newline) so every row has the
        # same number of fields and numpy builds a proper 2-D array.
        rows = [line.strip().split(',') for line in f if line.strip()]

    if not rows:
        raise ValueError("no data rows found in %s" % path)

    wine = np.array(rows)
    wine_label = wine[:, :1]
    wine_data = wine[:, 1:]

    # One label per feature COLUMN (not per row): label i names column i.
    # NOTE: these are plain column indices; a mapping from index to the
    # attribute name would be a more useful label.
    featLabels = list(range(wine_data.shape[1]))

    # The tree-building code expects the class label as the last column.
    wine_data = np.concatenate((wine_data, wine_label), axis=1)
    return wine_data,featLabels

# Row layout: row[:-1] holds the feature values, row[-1] holds the wine class label.
def informationEntropy(dataSet):
    """Return the base-2 Shannon entropy of the class labels in ``dataSet``.

    The class label of each row is its last element. An empty ``dataSet``
    yields ``0.0``.
    """
    total = len(dataSet)

    # Count how many rows carry each class label.
    counts = {}
    for row in dataSet:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1

    # H = -sum(p * log2(p)) over the observed labels.
    entropy = 0.0
    for count in counts.values():
        p = float(count/total)
        entropy -= p*log(p,2)
    return entropy

# Split out a sub-dataset; kept as a separate function so it can be reused.
def splitDataSet(dataSet,axis,feature):
    """Return the rows whose column ``axis`` equals ``feature``, minus that column.

    Args:
        dataSet: Iterable of rows (lists or numpy arrays).
        axis: Index of the column to test and then drop.
        feature: Value the column must equal for a row to be kept.

    Returns:
        A list of lists; each inner list is a matching row with column
        ``axis`` removed.
    """
    selected = (row for row in dataSet if row[axis] == feature)

    reduced = []
    for row in selected:
        head = row[:axis]
        # A numpy slice has no ``extend``; turn it into a plain list first.
        if isinstance(head, np.ndarray):
            head = head.tolist()
        head.extend(row[axis+1:])
        reduced.append(head)
    return reduced

# choose the best Feature to split
def chooseFeature(dataSet):
    """Return the index of the feature column with the largest information gain.

    Every column except the last (the class label) is a candidate. Each
    distinct value of a column forms its own partition; the gain is the
    entropy of ``dataSet`` minus the size-weighted entropy of the partitions.

    Returns:
        The winning column index, or ``-1`` when no column has a gain
        strictly greater than zero.
    """
    featureCount = len(dataSet[0])-1
    parentEntropy = informationEntropy(dataSet)

    winner = -1
    winnerGain = 0.0
    for col in range(featureCount):
        columnValues = [row[col] for row in dataSet]

        # Size-weighted entropy after splitting on every distinct value.
        # Values are compared as-is, so numeric data stored as strings is
        # treated as categorical.
        weightedEntropy = 0.0
        for val in set(columnValues):
            part = splitDataSet(dataSet,col,val)
            weight = len(part)/float(len(dataSet))
            weightedEntropy += weight*informationEntropy(part)

        gain = parentEntropy - weightedEntropy
        if gain > winnerGain:
            winnerGain = gain
            winner = col

    return winner

def majorityCnt(classList):
    """Return the most frequent value in ``classList``.

    When several values share the highest count, the one that appears first
    in ``classList`` is returned (the sort is stable).
    """
    tally = {}
    for vote in classList:
        tally[vote] = tally.get(vote, 0) + 1

    ranked = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    winner, _ = ranked[0]
    return winner


# Featlabels maps each feature column position to its feature label.
def createTree(dataSet,Featlabels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: Non-empty sequence of rows; the last element of each row is
            the class label and the preceding elements are feature values.
        Featlabels: List whose entry ``i`` is the label of feature column
            ``i`` of ``dataSet``. The list is not modified.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{featureLabel: {featureValue: subtree, ...}}``.

    Side effects:
        Prints ``"len is zero"`` when the rows have run out of feature
        columns.
    """
    classList = [example[-1] for example in dataSet]

    # Every row has the same class: this node is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # Only the class column is left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        print("len is zero")
        return majorityCnt(classList)

    bestFeat = chooseFeature(dataSet)

    # chooseFeature returns -1 when no feature gives a positive information
    # gain. Indexing with -1 would silently split on the class column, so
    # stop here with a majority vote instead.
    if bestFeat < 0:
        return majorityCnt(classList)

    bestFeatLabel = Featlabels[bestFeat]
    uniqueVals = set(example[bestFeat] for example in dataSet)

    # If the feature has a single value, splitting on it is pointless.
    # Vote over the class labels (rows themselves are not hashable).
    if len(uniqueVals) == 1:
        return majorityCnt(classList)

    # Labels for the columns that remain after bestFeat is removed. Built as
    # a new list so the caller's Featlabels is left untouched.
    remainingLabels = list(Featlabels)
    del remainingLabels[bestFeat]

    # The tree is a dict keyed by the feature label, then by feature value.
    myTree = {bestFeatLabel:{}}
    for value in uniqueVals:
        subdataSet = splitDataSet(dataSet,bestFeat,value)
        myTree[bestFeatLabel][value] = createTree(subdataSet,remainingLabels)

    return myTree



# Classify function: featLabels and testVec are used to look up the test sample's value for each feature.
def _findBranchKey(branches, rawValue):
    """Return the key of ``branches`` that corresponds to ``rawValue``.

    The string form of ``rawValue`` is tried first. If that is absent and
    ``rawValue`` is numeric, the first key that parses to the same float is
    used, so ``0.28`` matches a branch stored as ``'.28'``.

    Raises:
        KeyError: If no branch matches.
    """
    key = str(rawValue)
    if key in branches:
        return key

    try:
        target = float(rawValue)
    except (TypeError, ValueError):
        target = None

    if target is not None:
        for candidate in branches:
            try:
                if float(candidate) == target:
                    return candidate
            except (TypeError, ValueError):
                # Non-numeric branch key; it cannot match a numeric value.
                continue

    raise KeyError("no branch for feature value %r in decision tree" % (key,))


def classify(inputTree,featLabels,testVec):
    """Walk ``inputTree`` and return the class label predicted for ``testVec``.

    Args:
        inputTree: Tree of nested dicts shaped like
            ``{featureLabel: {featureValue: subtree_or_label}}``, or a bare
            class label (a tree that is a single leaf).
        featLabels: List of feature labels; the position of a label in this
            list is the position of that feature's value in ``testVec``.
        testVec: Feature values of the sample to classify.

    Returns:
        The class label stored at the leaf that is reached.

    Raises:
        KeyError: If the tree has no branch for one of the sample's values.
        ValueError: If a feature label in the tree is not in ``featLabels``.
    """
    node = inputTree
    # Descend until a non-dict value (a leaf holding the class label) is hit.
    while isinstance(node, dict):
        # Each internal node has a single key: the label of its split feature.
        featLabel = list(node.keys())[0]
        branches = node[featLabel]

        featIndex = featLabels.index(featLabel)
        node = branches[_findBranchKey(branches, testVec[featIndex])]

    return node


if __name__ == '__main__':
    # Build the tree from the whole data set. createTree receives its own
    # copy of the labels so featLabels stays intact for classify below.
    wine_data, featLabels = loadDateset()
    myTree = createTree(wine_data, list(featLabels))

    # Feature values of one sample, in feature-column order. They are given
    # as numbers here; classify converts each one with str() before lookup.
    test = [
        14.23, 1.71, 2.43, 15.6, 127, 2.8, 3.06,
        0.28, 2.29, 5.64, 1.04, 3.92, 1065,
    ]
    prediction = classify(myTree, featLabels, test)
    print(prediction)